Install required packages to access images from S3 storage

!pip install -U imagecodecs s3fs tifffile

!git clone https://github.com/jump-cellpainting/JUMP-Target
!git clone https://github.com/jump-cellpainting/datasets.git

Check the different data sources available in the CP-JUMP database

import pandas as pd
# Plate-level metadata for the CP-JUMP dataset: one row per assay plate with
# its generating source (lab), acquisition batch, plate barcode, and plate type.
jump_plates_metadata = pd.read_csv("datasets/metadata/plate.csv.gz")
jump_plates_metadata
Metadata_Source Metadata_Batch Metadata_Plate Metadata_PlateType
0 source_1 Batch1_20221004 UL000109 COMPOUND_EMPTY
1 source_1 Batch1_20221004 UL001641 COMPOUND
2 source_1 Batch1_20221004 UL001643 COMPOUND
3 source_1 Batch1_20221004 UL001645 COMPOUND
4 source_1 Batch1_20221004 UL001651 COMPOUND
... ... ... ... ...
2520 source_9 20211103-Run16 GR00004417 COMPOUND
2521 source_9 20211103-Run16 GR00004418 COMPOUND
2522 source_9 20211103-Run16 GR00004419 COMPOUND
2523 source_9 20211103-Run16 GR00004420 COMPOUND
2524 source_9 20211103-Run16 GR00004421 COMPOUND

2525 rows × 4 columns

jump_plates_metadata["Metadata_PlateType"].unique()
array(['COMPOUND_EMPTY', 'COMPOUND', 'DMSO', 'TARGET2', 'CRISPR', 'ORF',
       'TARGET1', 'POSCON8'], dtype=object)
jump_plates_metadata.groupby(["Metadata_Source", "Metadata_Batch"]).describe()
Metadata_Plate Metadata_PlateType
count unique top freq count unique top freq
Metadata_Source Metadata_Batch
source_1 Batch1_20221004 9 9 UL000109 1 9 2 COMPOUND 6
Batch2_20221006 7 7 UL001647 1 7 1 COMPOUND 7
Batch3_20221010 8 8 UL000087 1 8 1 COMPOUND 8
Batch4_20221012 8 8 UL000081 1 8 1 COMPOUND 8
Batch5_20221030 11 11 UL000561 1 11 2 COMPOUND 10
... ... ... ... ... ... ... ... ... ...
source_9 20210918-Run11 9 9 GR00004367 1 9 2 COMPOUND 8
20210918-Run12 8 8 GR00004377 1 8 1 COMPOUND 8
20211013-Run14 13 13 GR00003279 1 13 2 COMPOUND 12
20211102-Run15 11 11 GR00004391 1 11 2 COMPOUND 10
20211103-Run16 17 17 GR00004405 1 17 2 COMPOUND 16

149 rows × 8 columns

# Well-level platemap for the CRISPR target plate: maps each well position
# (A01..P24) to the Broad sample identifier of the perturbation it contains.
crispr_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_crispr_platemap.tsv", sep="\t")
crispr_wells_metadata["Plate_type"] = "CRISPR"
# Numeric class label used later as the training target (CRISPR -> 1).
crispr_wells_metadata["Plate_label"] = 1
crispr_wells_metadata
well_position broad_sample Plate_type Plate_label
0 A01 BRDN0001480888 CRISPR 1
1 A02 BRDN0001483495 CRISPR 1
2 A03 BRDN0001147364 CRISPR 1
3 A04 BRDN0001490272 CRISPR 1
4 A05 BRDN0001480510 CRISPR 1
... ... ... ... ...
379 P20 BRDN0001145303 CRISPR 1
380 P21 BRDN0001484228 CRISPR 1
381 P22 BRDN0001487618 CRISPR 1
382 P23 BRDN0001487864 CRISPR 1
383 P24 BRDN0000735603 CRISPR 1

384 rows × 4 columns

# Well-level platemap for the ORF (open reading frame) target plate.
orf_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_orf_platemap.tsv", sep="\t")
orf_wells_metadata["Plate_type"] = "ORF"
# Numeric class label used later as the training target (ORF -> 2).
orf_wells_metadata["Plate_label"] = 2
orf_wells_metadata
well_position broad_sample Plate_type Plate_label
0 A01 ccsbBroad304_00900 ORF 2
1 A02 ccsbBroad304_07795 ORF 2
2 A03 ccsbBroad304_02826 ORF 2
3 A04 ccsbBroad304_01492 ORF 2
4 A05 ccsbBroad304_00691 ORF 2
... ... ... ... ...
379 P20 ccsbBroad304_00277 ORF 2
380 P21 ccsbBroad304_06464 ORF 2
381 P22 ccsbBroad304_00476 ORF 2
382 P23 ccsbBroad304_01649 ORF 2
383 P24 ccsbBroad304_03934 ORF 2

384 rows × 4 columns

# Well-level platemap for the compound plate; unlike the genetic platemaps it
# also carries a 'solvent' column (DMSO).
compound_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_compound_platemap.tsv", sep="\t")
compound_wells_metadata["Plate_type"] = "COMPOUND"
# Numeric class label used later as the training target (COMPOUND -> 3).
compound_wells_metadata["Plate_label"] = 3
compound_wells_metadata
well_position broad_sample solvent Plate_type Plate_label
0 A01 BRD-A86665761-001-01-1 DMSO COMPOUND 3
1 A02 NaN DMSO COMPOUND 3
2 A03 BRD-A22032524-074-09-9 DMSO COMPOUND 3
3 A04 BRD-A01078468-001-14-8 DMSO COMPOUND 3
4 A05 BRD-K48278478-001-01-2 DMSO COMPOUND 3
... ... ... ... ... ...
379 P20 BRD-K68982262-001-01-4 DMSO COMPOUND 3
380 P21 BRD-K24616672-003-20-1 DMSO COMPOUND 3
381 P22 BRD-A82396632-008-30-8 DMSO COMPOUND 3
382 P23 BRD-K61250553-003-30-6 DMSO COMPOUND 3
383 P24 BRD-K70358946-001-17-3 DMSO COMPOUND 3

384 rows × 5 columns

# Stack the three platemaps into a single table; the 'solvent' column exists
# only for the compound plate, so it is NaN on ORF and CRISPR rows.
wells_metadata = pd.concat([compound_wells_metadata, orf_wells_metadata, crispr_wells_metadata])
wells_metadata
well_position broad_sample solvent Plate_type Plate_label
0 A01 BRD-A86665761-001-01-1 DMSO COMPOUND 3
1 A02 NaN DMSO COMPOUND 3
2 A03 BRD-A22032524-074-09-9 DMSO COMPOUND 3
3 A04 BRD-A01078468-001-14-8 DMSO COMPOUND 3
4 A05 BRD-K48278478-001-01-2 DMSO COMPOUND 3
... ... ... ... ... ...
379 P20 BRDN0001145303 NaN CRISPR 1
380 P21 BRDN0001484228 NaN CRISPR 1
381 P22 BRDN0001487618 NaN CRISPR 1
382 P23 BRDN0001487864 NaN CRISPR 1
383 P24 BRDN0000735603 NaN CRISPR 1

1152 rows × 5 columns

# Wells with no perturbation (missing broad_sample) are relabeled as class 0,
# the same label assigned downstream to DMSO/negative-control wells.
wells_metadata.loc[wells_metadata["broad_sample"].isna(), "Plate_label"] = 0
wells_metadata
well_position broad_sample solvent Plate_type Plate_label
0 A01 BRD-A86665761-001-01-1 DMSO COMPOUND 3
1 A02 NaN DMSO COMPOUND 0
2 A03 BRD-A22032524-074-09-9 DMSO COMPOUND 3
3 A04 BRD-A01078468-001-14-8 DMSO COMPOUND 3
4 A05 BRD-K48278478-001-01-2 DMSO COMPOUND 3
... ... ... ... ... ...
379 P20 BRDN0001145303 NaN CRISPR 1
380 P21 BRDN0001484228 NaN CRISPR 1
381 P22 BRDN0001487618 NaN CRISPR 1
382 P23 BRDN0001487864 NaN CRISPR 1
383 P24 BRDN0000735603 NaN CRISPR 1

1152 rows × 5 columns

Review information related to each perturbation in the Broad Institute Genetic Perturbation Platform (https://portals.broadinstitute.org/gpp/public/)

Review the CP-JUMP data directly from the AWS bucket

The Cell Painting Image Collection Registry of Open Data on AWS (https://registry.opendata.aws/cellpainting-gallery/) is a collection of microscopy image sets.

The AWS bucket can be found here: https://cellpainting-gallery.s3.amazonaws.com/index.html

Get the URL of each assay plate from the bucket

import s3fs
# Anonymous (unsigned) access: the Cell Painting Gallery bucket is public.
fs = s3fs.S3FileSystem(anon=True)

batch_names = {}
plate_paths = {}
source_names = {}
plate_types = {}

# One summary row per (source, batch); describe() exposes the batch's most
# frequent plate type as the 'top' statistic.
batch_summary = jump_plates_metadata.groupby(["Metadata_Source", "Metadata_Batch"]).describe()

for (source_name, batch_name), src_row in batch_summary.iterrows():
    # Ignore 'source_8' since the naming of the images is not standard
    if source_name == "source_8":
        continue

    plate_type = src_row["Metadata_PlateType"]["top"]

    batch_prefix = f"cellpainting-gallery/cpg0016-jump/{source_name}/images/{batch_name}/images/"
    for entry in fs.ls(batch_prefix):
        plate_path = entry.rsplit("/", 1)[-1]
        if not plate_path:
            continue

        # The plate barcode is everything before the acquisition timestamp.
        plate_name = plate_path.split("__")[0]

        # Keyed by plate name: a repeated name overwrites the earlier entry.
        source_names[plate_name] = source_name
        batch_names[plate_name] = batch_name
        plate_types[plate_name] = plate_type
        plate_paths[plate_name] = plate_path

plate_maps = pd.DataFrame({"Plate_name": list(batch_names.keys())})
plate_maps["Source_name"] = plate_maps["Plate_name"].map(source_names)
plate_maps["Batch_name"] = plate_maps["Plate_name"].map(batch_names)
plate_maps["Plate_type"] = plate_maps["Plate_name"].map(plate_types)
plate_maps["Plate_path"] = plate_maps["Plate_name"].map(plate_paths)
plate_maps
Plate_name Source_name Batch_name Plate_type Plate_path
0 UL000109 source_1 Batch1_20221004 COMPOUND UL000109__2022-10-05T06_35_06-Measurement1
1 UL001641 source_1 Batch1_20221004 COMPOUND UL001641__2022-10-04T23_16_28-Measurement1
2 UL001643 source_1 Batch1_20221004 COMPOUND UL001643__2022-10-04T18_52_42-Measurement2
3 UL001645 source_1 Batch1_20221004 COMPOUND UL001645__2022-10-05T00_44_11-Measurement1
4 UL001651 source_1 Batch1_20221004 COMPOUND UL001651__2022-10-04T20_20_52-Measurement1
... ... ... ... ... ...
2333 GR00004417 source_9 20211103-Run16 COMPOUND GR00004417
2334 GR00004418 source_9 20211103-Run16 COMPOUND GR00004418
2335 GR00004419 source_9 20211103-Run16 COMPOUND GR00004419
2336 GR00004420 source_9 20211103-Run16 COMPOUND GR00004420
2337 GR00004421 source_9 20211103-Run16 COMPOUND GR00004421

2338 rows × 5 columns

# Subset of plates that received chemical (compound) perturbations.
comp_plate_maps = plate_maps.query("Plate_type=='COMPOUND'")
comp_plate_maps
Plate_name Source_name Batch_name Plate_type Plate_path
0 UL000109 source_1 Batch1_20221004 COMPOUND UL000109__2022-10-05T06_35_06-Measurement1
1 UL001641 source_1 Batch1_20221004 COMPOUND UL001641__2022-10-04T23_16_28-Measurement1
2 UL001643 source_1 Batch1_20221004 COMPOUND UL001643__2022-10-04T18_52_42-Measurement2
3 UL001645 source_1 Batch1_20221004 COMPOUND UL001645__2022-10-05T00_44_11-Measurement1
4 UL001651 source_1 Batch1_20221004 COMPOUND UL001651__2022-10-04T20_20_52-Measurement1
... ... ... ... ... ...
2333 GR00004417 source_9 20211103-Run16 COMPOUND GR00004417
2334 GR00004418 source_9 20211103-Run16 COMPOUND GR00004418
2335 GR00004419 source_9 20211103-Run16 COMPOUND GR00004419
2336 GR00004420 source_9 20211103-Run16 COMPOUND GR00004420
2337 GR00004421 source_9 20211103-Run16 COMPOUND GR00004421

1905 rows × 5 columns

# Plates used for the downstream datasets: genetic perturbations (CRISPR, ORF)
# plus DMSO negative-control plates.
pert_plate_maps = plate_maps[plate_maps["Plate_type"].isin(["CRISPR", "ORF", "DMSO"])]
pert_plate_maps
Plate_name Source_name Batch_name Plate_type Plate_path
142 Dest210628-161651 source_10 2021_06_28_U2OS_48_hr_run9 DMSO Dest210628-161651
143 Dest210628-162003 source_10 2021_06_28_U2OS_48_hr_run9 DMSO Dest210628-162003
457 CP-CC9-R1-01 source_13 20220914_Run1 CRISPR CP-CC9-R1-01
458 CP-CC9-R1-02 source_13 20220914_Run1 CRISPR CP-CC9-R1-02
459 CP-CC9-R1-03 source_13 20220914_Run1 CRISPR CP-CC9-R1-03
... ... ... ... ... ...
1591 BR00127145 source_4 2021_08_30_Batch13 ORF BR00127145__2021-09-22T04_01_46-Measurement1
1592 BR00127146 source_4 2021_08_30_Batch13 ORF BR00127146__2021-09-22T12_25_07-Measurement1
1593 BR00127147 source_4 2021_08_30_Batch13 ORF BR00127147__2021-09-18T10_27_12-Measurement1
1594 BR00127148 source_4 2021_08_30_Batch13 ORF BR00127148__2021-09-21T11_44_23-Measurement1
1595 BR00127149 source_4 2021_08_30_Batch13 ORF BR00127149__2021-09-18T02_10_04-Measurement1

433 rows × 5 columns

pert_plate_maps["Source_name"].unique()
array(['source_10', 'source_13', 'source_4'], dtype=object)
comp_plate_maps["Source_name"].unique()
array(['source_1', 'source_10', 'source_11', 'source_15', 'source_2',
       'source_3', 'source_5', 'source_6', 'source_7', 'source_9'],
      dtype=object)

Split the dataset into Training, Validation, and Test sets

import random
import math

# Plates are split per acquisition batch so every split draws plates from each
# batch (a simple stratification against batch effects).
trn_plates = []
val_plates = []
tst_plates = []

# Target proportions; the training set receives whatever remains after the
# test and validation plates are taken from each batch.
trn_proportion = 0.7
val_proportion = 0.2
tst_proportion = 0.1

for batch_name in pert_plate_maps["Batch_name"].unique():
    # Boolean indexing instead of a query() f-string, so batch names that
    # contain quotes cannot break the selection expression.
    in_batch = pert_plate_maps["Batch_name"] == batch_name
    plate_names = pert_plate_maps.loc[in_batch, "Plate_name"].tolist()
    random.shuffle(plate_names)

    # Ceil so even small batches contribute at least one plate to each of the
    # test and validation splits. math.ceil already returns an int.
    tst_plates_count = math.ceil(len(plate_names) * tst_proportion)
    val_plates_count = math.ceil(len(plate_names) * val_proportion)

    tst_plates += plate_names[:tst_plates_count]
    val_plates += plate_names[tst_plates_count:tst_plates_count + val_plates_count]
    trn_plates += plate_names[tst_plates_count + val_plates_count:]
trn_plates[:5]
['CP-CC9-R1-16',
 'CP-CC9-R1-22',
 'CP-CC9-R1-26',
 'CP-CC9-R1-12',
 'CP-CC9-R1-29']
val_plates[:5]
['Dest210628-162003',
 'CP-CC9-R1-13',
 'CP-CC9-R1-28',
 'CP-CC9-R1-17',
 'CP-CC9-R1-24']
tst_plates[:5]
['Dest210628-161651',
 'CP-CC9-R1-08',
 'CP-CC9-R1-01',
 'CP-CC9-R1-07',
 'CP-CC9-R2-22']
print("Training set size:", len(trn_plates))
print("Validation set size:", len(val_plates))
print("Testing set size:", len(tst_plates))
Training set size: 283
Validation set size: 96
Testing set size: 54

Create a Dataset that can be used with PyTorch

# @title Definition of a Dataset class capable to pull images from AWS S3 buckets
import random
import numpy as np
import string
import s3fs

from itertools import product

from PIL import Image
import tifffile

from torch.utils.data import IterableDataset, get_worker_info


from time import perf_counter


def s3dataset_worker_init_fn(worker_id):
    """TiffS3Dataset multiprocess worker initialization function.

    Assigns each DataLoader worker an interleaved slice of the plate list
    (worker_id, worker_id + num_workers, ...) so workers stream disjoint
    plates, and reseeds the Python/NumPy RNGs so that shuffling workers do
    not all draw the same random wells.
    """
    # Use the name already imported from torch.utils.data at the top of
    # this cell instead of the fully qualified torch.utils.data path.
    worker_info = get_worker_info()
    w_sel = slice(worker_id, None, worker_info.num_workers)

    dataset_obj = worker_info.dataset

    # Reset the random number generators in each worker. The original code
    # computed torch.initial_seed() but never applied it, so every worker
    # inherited the parent's RNG state and produced identical samples.
    worker_seed = worker_info.seed % 2 ** 32
    random.seed(worker_seed)
    np.random.seed(worker_seed)

    dataset_obj._worker_sel = w_sel
    dataset_obj._worker_id = worker_id
    dataset_obj._num_workers = worker_info.num_workers


def load_well(plate_metadata, well_row, well_col, field_id, channels, s3):
    """Fetch all channel images for one field of view of a well from S3.

    Parameters
    ----------
    plate_metadata : dict
        Record with 'Source_name', 'Batch_name', 'Plate_path' and 'Plate_name'.
    well_row, well_col : int
        Zero-based well coordinates (row A == 0, column 1 == 0).
    field_id : int
        Zero-based field-of-view index within the well.
    channels : int
        Number of fluorescence channels to fetch.
    s3 : s3fs.S3FileSystem
        Open (anonymous) S3 filesystem handle.

    Returns
    -------
    numpy.ndarray of shape (channels, H, W), float32 in [0, 1], or None when
    any channel image is missing from the bucket.
    """
    curr_well_image = []

    plate_path = "cellpainting-gallery/cpg0016-jump/" + plate_metadata["Source_name"] + "/images/" + plate_metadata["Batch_name"] + "/images/" + plate_metadata["Plate_path"]

    for channel_id in range(channels):
        if plate_metadata["Source_name"] in ["source_1", "source_3", "source_4", "source_9", "source_11", "source_15"]:
            # rXXcYYfZZ-style naming used by these sources.
            image_suffix = f"Images/r{well_row + 1:02d}c{well_col + 1:02d}f{field_id + 1:02d}p01-ch{channel_id + 1}sk1fk1fl1.tiff"

        else:
            # The per-channel action index (A) differs between sources.
            if plate_metadata["Source_name"] in ["source_2", "source_5"]:
                a_locs = [1, 2, 3, 4, 5]
            elif plate_metadata["Source_name"] in ["source_6", "source_10"]:
                a_locs = [1, 2, 2, 3, 1, 4]
            elif plate_metadata["Source_name"] in ["source_7", "source_13"]:
                a_locs = [1, 1, 2, 3, 4]
            else:
                # Previously this fell through with 'a_locs' unbound and
                # crashed with a NameError; fail with a clear message instead.
                raise ValueError(f"Unsupported source: {plate_metadata['Source_name']}")

            # Single quotes inside the f-string: nested double quotes are a
            # SyntaxError before Python 3.12.
            image_suffix = f"{plate_metadata['Plate_name']}_{string.ascii_uppercase[well_row]}{well_col + 1:02d}_T0001F{field_id + 1:03d}L01A{a_locs[channel_id]:02d}Z01C{channel_id + 1:02d}.tif"

        image_url = "s3://" + plate_path + "/" + image_suffix

        try:
            with s3.open(image_url, 'rb') as f:
                curr_image = tifffile.imread(f)

        except FileNotFoundError:
            print("Failed retrieving:", image_url)
            return None

        # Normalize 16-bit intensities to [0, 1].
        curr_image = curr_image.astype(np.float32)
        curr_image /= 2 ** 16 - 1

        curr_well_image.append(curr_image)

    curr_well_image = np.array(curr_well_image)

    return curr_well_image


class TiffS3Dataset(IterableDataset):
    """Stream Cell Painting fields of view directly from the public S3 bucket.

    Yields (image, label, plate_metadata) tuples where image is a float32
    array of shape (channels, 1080, 1080) normalized to [0, 1]. With
    shuffle=True, wells are drawn at random instead of in plate order, so the
    dataset can supply a virtually unlimited stream of samples.
    """
    def __init__(self, plate_maps, wells_metadata, plate_names, well_rows=16, well_cols=24, fields=4, channels=5, shuffle=False):
        # NOTE: the defaults were originally well_rows=24, well_cols=16, which
        # is transposed for a 384-well plate (rows A-P = 16, columns 1-24):
        # row indices above 15 would address nonexistent rows ('Q'..'X').
        # All callers in this file pass these values explicitly.
        # Fixed: super(TiffS3Dataset).__init__() never actually invoked
        # IterableDataset.__init__.
        super().__init__()

        self._plate_maps = plate_maps
        self._wells_metadata = wells_metadata

        self._plate_names = plate_names
        self._well_rows = well_rows
        self._well_cols = well_cols
        self._fields = fields
        self._channels = channels

        self._shuffle = shuffle

        # Worker bookkeeping; overwritten by s3dataset_worker_init_fn when
        # used with DataLoader num_workers > 0. slice(None) selects all
        # plates for single-process use.
        self._worker_sel = slice(None)
        self._worker_id = 0
        self._num_workers = 1

        self._s3 = None

    def __iter__(self):
        # Create the S3 handle lazily inside the consuming process; s3fs
        # handles should not be shared across DataLoader workers.
        self._s3 = s3fs.S3FileSystem(anon=True)

        # Select this worker's plates WITHOUT mutating self._plate_names:
        # the original re-assigned the sliced list, so each new epoch (or
        # worker) shrank the plate list further.
        plate_names = self._plate_names[self._worker_sel]

        well_row_range = range(self._well_rows)
        well_col_range = range(self._well_cols)
        fields_range = range(self._fields)

        for plate_name, well_row, well_col, field_id in product(plate_names, well_row_range, well_col_range, fields_range):
            if self._shuffle:
                # Replace the deterministic walk with a uniformly random draw.
                plate_name = random.choice(plate_names)
                well_row = random.randrange(self._well_rows)
                well_col = random.randrange(self._well_cols)
                field_id = random.randrange(self._fields)

            # Boolean mask instead of a query() f-string, so plate names with
            # quotes cannot break the expression.
            curr_plate_map = self._plate_maps[self._plate_maps["Plate_name"] == plate_name]

            # Check emptiness BEFORE indexing record [0]: the original indexed
            # first, which raised IndexError for unknown plates instead of
            # skipping them.
            if not len(curr_plate_map):
                continue

            curr_plate_metadata = curr_plate_map.to_dict(orient='records')[0]

            curr_image = load_well(curr_plate_metadata, well_row, well_col, field_id, self._channels, self._s3)

            if curr_image is None:
                continue

            # Crop or zero-pad every field to a fixed 1080x1080 canvas so
            # images from different sources can be batched together.
            curr_image = curr_image[:, :1080, :1080]
            _, h, w = curr_image.shape
            pad_h = 1080 - h
            pad_w = 1080 - w

            if pad_h or pad_w:
                curr_image = np.pad(curr_image, ((0, 0), (0, pad_h), (0, pad_w)))

            if curr_plate_metadata["Plate_type"] == "DMSO":
                # Negative-control plates: every well is the control class.
                curr_label = 0

            else:
                # Look up this well's label in the platemap; boolean mask
                # avoids the nested-double-quote f-string that is a
                # SyntaxError before Python 3.12.
                well_position = f"{string.ascii_uppercase[well_row]}{well_col + 1:02d}"
                well_match = ((self._wells_metadata["Plate_type"] == curr_plate_metadata["Plate_type"])
                              & (self._wells_metadata["well_position"] == well_position))
                curr_label = self._wells_metadata.loc[well_match, "Plate_label"]

                if not len(curr_label):
                    continue

                curr_label = curr_label.item()

            yield curr_image, curr_label, curr_plate_metadata

        self._s3 = None

Create the datasets from the list of URLs

# Positional arguments: 16 well rows (A-P), 24 well columns, 9 fields of view
# per well, 5 fluorescence channels.
# NOTE(review): shuffle=True on the validation and test datasets makes their
# sampled wells non-deterministic between runs — confirm this is intended.
training_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, trn_plates, 16, 24, 9, 5, shuffle=True)
validation_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, val_plates, 16, 24, 9, 5, shuffle=True)
testing_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, tst_plates, 16, 24, 9, 5, shuffle=True)

Import a pre-trained model from torchvision

import torch
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
# Pretrained MobileNetV3-Small; 'weights' is kept in a variable so its
# preprocessing transforms can be reused when feeding images to the model.
weights = MobileNet_V3_Small_Weights.DEFAULT
model = mobilenet_v3_small(weights=weights)

Change the last layers of the pre-trained model to convert it into a feature extraction function

# Keep a reference to the original global-average-pool layer: it is applied
# manually to the reshaped feature maps during feature extraction.
org_avgpool = model.avgpool
# Replace the pooling head and classifier with identities so the network
# returns raw convolutional features instead of class logits.
model.avgpool = torch.nn.Identity()
model.classifier = torch.nn.Identity()

# Only move the model to the GPU when one exists. The original unconditional
# model.cuda() crashed on CPU-only machines, even though the extraction loop
# below already guards its inputs with torch.cuda.is_available().
if torch.cuda.is_available():
    model.cuda()
model.eval()
MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=72, bias=False)
          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (3): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 88, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(88, 88, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=88, bias=False)
          (1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(88, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (4): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(96, 96, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=96, bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(96, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (5): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (6): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (7): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(120, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (8): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(40, 144, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(144, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (9): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(288, 288, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=288, bias=False)
          (1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(288, 72, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(72, 288, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (10): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (11): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (12): Conv2dNormActivation(
      (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
  )
  (avgpool): Identity()
  (classifier): Identity()
)
model_transforms = weights.transforms()
model_transforms
ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

Create a torch DataLoader to train PyTorch models

from tqdm.auto import tqdm
from torch.utils.data.dataloader import DataLoader
# Number of fields of view processed per forward pass.
batch_size = 100

# worker_init_fn gives each of the 8 workers a disjoint slice of the plates.
training_dl = DataLoader(training_ds, batch_size=batch_size, num_workers=8, worker_init_fn=s3dataset_worker_init_fn)
features = []
targets = []

for i, (x, y, _) in tqdm(enumerate(training_dl)):
    # Stop after 1000 batches (100,000 fields of view at batch_size=100).
    if i >= 1000:
        break

    # x: (batch, channels, 1080, 1080). The pretrained model expects
    # 3-channel input, so each fluorescence channel is replicated to RGB and
    # sent through the network as an independent image.
    b, c, h, w = x.shape
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()

    with torch.no_grad():
        x_out = model(x_t)
        # The headless model returns flattened 576x7x7 feature maps (7x7 is
        # the spatial size after the 224x224 crop); sum over the original
        # fluorescence channels, then apply the saved global average pool.
        x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

    features.append(x_out)
    targets.append(y)

    # Persist a checkpoint every 100 batches and reset the buffers to bound
    # memory usage.
    if (i + 1) % 100 == 0:
        features = torch.cat(features, dim=0)

        # The labels are mapped as NONE/DMSO = 0, CRISPR = 1, ORF = 2, and COMPOUND = 3
        targets = torch.cat(targets, dim=0)

        torch.save(dict(features=features, targets=targets), f"trn_features_{i // 100:03d}.pt")
        print("Saved features checkpoint", f"trn_features_{i // 100:03d}.pt", features.shape, targets.shape)

        features = []
        targets = []
Saved features checkpoint trn_features_000.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_001.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_002.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_003.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_004.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_005.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_006.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_007.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_008.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_009.pt torch.Size([10000, 576]) torch.Size([10000])
val_features = []
val_targets = []

validation_dl = DataLoader(validation_ds, batch_size=100, num_workers=8, worker_init_fn=s3dataset_worker_init_fn)

# Extract features for (at most) the first 50 validation batches.
for i, (x, y, _) in enumerate(tqdm(validation_dl, total=50)):
    if i >= 50:
        break

    # Replicate each single-channel plane to 3 channels for the
    # ImageNet-pretrained backbone.
    b, c, h, w = x.shape
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()

    with torch.no_grad():
        x_out = model(x_t)
        # .cpu() here matches the training loop: keeps GPU memory flat while
        # accumulating and ensures the tensors saved later are CPU tensors.
        x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

    val_features.append(x_out)
    val_targets.append(y)

val_features = torch.cat(val_features, dim=0)

# The labels are mapped as NONE/DMSO = 0, ORF = 1, CRISPR = 2, and COMPOUND = 3
val_targets = torch.cat(val_targets, dim=0)
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^Exception ignored in:     ^self._shutdown_workers()^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

^Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    if w.is_alive():      File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive

self._shutdown_workers()
       File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
assert self._parent_pid == os.getpid(), 'can only test a child process' 
     if w.is_alive(): 
             ^  ^ ^  ^ ^^ ^ ^^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    ^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^^ 
^   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^     ^ assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^
^   ^  ^ ^  ^  ^ ^^ ^ ^^^ ^ ^^ ^^ ^^^^^
^^AssertionError^^^: ^^can only test a child process
^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^
^^Traceback (most recent call last):
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^^self._shutdown_workers()^^
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^    ^^if w.is_alive():^^
^ ^^^ ^^ ^^^ 
^ AssertionError^ : ^ can only test a child process^^
^^
^AssertionError^: Exception ignored in: ^can only test a child process<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^Exception ignored in: self._shutdown_workers()^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():    

self._shutdown_workers()  
    File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
         if w.is_alive():  
      ^ ^  ^  ^^^ ^ ^^ ^^^^^^^^^^^^^^^^
^^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    ^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^ ^^ ^^ 
 ^   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^     ^ assert self._parent_pid == os.getpid(), 'can only test a child process' ^
  ^  ^^^ ^^ ^^ ^ ^ ^^ ^^^^ ^^^ ^
^ AssertionError^^: ^^^can only test a child process^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^AssertionError^: ^can only test a child process^
^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>

AssertionErrorTraceback (most recent call last):
:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
can only test a child process
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>if w.is_alive():

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
       if w.is_alive():^
^  ^ ^ ^ ^ ^^ ^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    assert self._parent_pid == os.getpid(), 'can only test a child process'^^
           ^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    ^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^ ^ ^ ^ ^ ^ ^^   ^^^ ^ ^^^^^^^^^^^^^^^^^
AssertionError: ^can only test a child process^
^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_G23_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_L14_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_C06_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_N23_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_P12_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_I11_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_J03_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_F11_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_M11_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_J16_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_C18_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_P06_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_D18_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_K21_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_L12_T0001F009L01A01Z01C01.tif
# Persist the validation features extracted above.
torch.save(dict(features=val_features, targets=val_targets), "val_features.pt")

tst_features = []
tst_targets = []

testing_dl = DataLoader(testing_ds, batch_size=batch_size, num_workers=8, worker_init_fn=s3dataset_worker_init_fn)

# Extract features for (at most) the first 50 test batches.
for i, (x, y, _) in enumerate(tqdm(testing_dl, total=50)):
    if i >= 50:
        break

    # Replicate each single-channel plane to 3 channels for the
    # ImageNet-pretrained backbone.
    b, c, h, w = x.shape
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()

    with torch.no_grad():
        x_out = model(x_t)
        # .cpu() here matches the training loop: keeps GPU memory flat while
        # accumulating and keeps the concatenated result on the CPU.
        x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

    tst_features.append(x_out)
    tst_targets.append(y)

tst_features = torch.cat(tst_features, dim=0)

# The labels are mapped as NONE/DMSO = 0, ORF = 1, CRISPR = 2, and COMPOUND = 3
tst_targets = torch.cat(tst_targets, dim=0)
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Exception ignored in: Exception ignored in: Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>Exception ignored in:         Traceback (most recent call last):


<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>self._shutdown_workers()self._shutdown_workers()Exception ignored in: Exception ignored in: Traceback (most recent call last):


Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>    
    
            self._shutdown_workers()if w.is_alive():self._shutdown_workers()Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
if w.is_alive():self._shutdown_workers()
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__



  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive

       File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      AssertionError  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
 self._shutdown_workers()self._shutdown_workers()         : 
     if w.is_alive():
if w.is_alive(): if w.is_alive(): can only test a child process  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers




               if w.is_alive():if w.is_alive():  ^

    ^  ^  ^ ^     ^^   ^ ^   ^   ^^^^ ^^ ^^  ^^^^ ^  ^^^^^^^^ ^^^^^^^^^
^^^^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^^^    ^^^^^
assert self._parent_pid == os.getpid(), 'can only test a child process'^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 ^^^^^^^
^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^

    ^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    
assert self._parent_pid == os.getpid(), 'can only test a child process'^    
 
  ^assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 
          ^  ^assert self._parent_pid == os.getpid(), 'can only test a child process' ^ 
                      assert self._parent_pid == os.getpid(), 'can only test a child process'  
    ^^ ^   ^ ^ ^  ^   ^  ^  ^ ^^^^^^
^^^  ^ ^ ^ ^ ^ ^^^^^^ ^ ^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^^^ ^^^^^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^ ^^^^^^^^
^^^^^^^^ ^^^^ ^^^^^ ^^^^^^^^^^^^^^^ ^^^^^ ^^^^^ ^^^^^ ^^^^^^ ^^^^^^^^^^^^^^^ ^^^ ^ ^^^^^^^^^^^^^^^^^^^^
^^^AssertionError^^^^^^: ^^^can only test a child process^^^
^^AssertionError^^^^: ^^^^can only test a child process^
^
^^
^^^^^AssertionError^^^: 
^^^^AssertionErrorcan only test a child process^: ^^can only test a child process
^^
^^^^^^^
AssertionError: 
can only test a child process
^AssertionError: ^can only test a child process^
^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
    Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
 Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>     ^
Traceback (most recent call last):
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^
    ^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^Exception ignored in:     ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>self._shutdown_workers()if w.is_alive():

^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ ^ ^    ^ 
if w.is_alive(): Traceback (most recent call last):
^
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:  
  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>      File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
     self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process' 
^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

^  Traceback (most recent call last):
      ^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^if w.is_alive():^
       ^self._shutdown_workers() ^Exception ignored in:   
^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
 ^
^  Traceback (most recent call last):
^     if w.is_alive(): ^ 
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  ^ ^  ^     
^ self._shutdown_workers()^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ ^
^    ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ ^assert self._parent_pid == os.getpid(), 'can only test a child process'    ^ ^^^
^if w.is_alive(): ^^^ ^
 ^
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^  ^      ^^^^assert self._parent_pid == os.getpid(), 'can only test a child process' ^ ^^
 ^^^^ ^^
^  ^   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 ^     ^^   assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^ ^ ^^
^^^ ^^^ ^^^ ^ ^^ ^^ ^^ 
^^^ Exception ignored in: ^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
  ^  ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>      ^^ ^^assert self._parent_pid == os.getpid(), 'can only test a child process'
 ^
^^ ^Traceback (most recent call last):
^^^  ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^^ ^^    ^ ^^^self._shutdown_workers()
 ^^ ^AssertionError^
^: ^^ ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ can only test a child process
^   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive

^    ^     ^^ ^assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():^^
^
 ^^^  ^^^^ ^^^^ ^ ^^  ^^^Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^ ^
^ ^^Traceback (most recent call last):
^^ ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^  ^ ^    ^ ^^self._shutdown_workers() ^^
^^^ ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^^^^^^^^^    ^^^^^^^^^if w.is_alive():^^^^^
^^ ^^^^^ ^^^^^ ^^^^^ ^^^^^^ ^^^^^
^ ^^^^ AssertionError
^^^AssertionError^
^^: : ^^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^can only test a child process^can only test a child process^
^^^    
^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^^Exception ignored in: ^^^^ ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^^ Exception ignored in: ^^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^ ^ ^^

^AssertionErrorTraceback (most recent call last):
 : 
^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    Traceback (most recent call last):
^     can only test a child process^^ self._shutdown_workers()
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

 ^^

     ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
AssertionError^self._shutdown_workers()Exception ignored in:  ^Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>    
^if w.is_alive():^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
:  

     Traceback (most recent call last):
can only test a child process  ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

 Traceback (most recent call last):
       File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^if w.is_alive():self._shutdown_workers()^       ^
 self._shutdown_workers()  
Exception ignored in: ^
 ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
 ^^^ ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^    ^
      ^ if w.is_alive():AssertionErrorif w.is_alive(): ^
 ^ ^: 
Traceback (most recent call last):

^^can only test a child process  ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
 ^  ^    ^^
 ^^ ^self._shutdown_workers()^ ^^
 ^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ ^^^ Exception ignored in:     ^^^  ^^^if w.is_alive(): ^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^^^^
^ ^^ ^^^^ 
 ^^^^Traceback (most recent call last):
 ^^^^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^ ^^^    ^^^^ ^self._shutdown_workers()^^ ^^^^^
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

^^^^^    ^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^if w.is_alive():^^^^

^^    ^
^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 ^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^^    
^^     ^^ assert self._parent_pid == os.getpid(), 'can only test a child process'assert self._parent_pid == os.getpid(), 'can only test a child process'^^^ ^ 

 ^
 ^  ^   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive

^  ^AssertionError^      : ^^ assert self._parent_pid == os.getpid(), 'can only test a child process'   
^ 
can only test a child process   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive

   ^  ^      ^   assert self._parent_pid == os.getpid(), 'can only test a child process'   ^  ^ 
Exception ignored in:   
 ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>   AssertionError^ 
 ^ :  ^   ^can only test a child process Traceback (most recent call last):
^ ^^
  ^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ ^^^^ ^^^^ ^^    ^^^Exception ignored in: ^ ^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>self._shutdown_workers()^^^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
^
^ Traceback (most recent call last):
^^^       File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^assert self._parent_pid == os.getpid(), 'can only test a child process'^^      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^
^ ^^    ^^^if w.is_alive():  ^self._shutdown_workers()
^ ^
^^^ ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
 ^^^     ^^ if w.is_alive():^ ^^ ^^^
^ ^ ^  ^ ^^^   ^^^^ ^^^ ^^^ ^^^^ ^^ ^^^ ^^^^ ^^^ ^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^AssertionError^^^^^
: ^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^can only test a child process^^    ^^

^^^AssertionError^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^: 

^^can only test a child process^^

^AssertionError
^^: Exception ignored in: AssertionError   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>can only test a child process^ can only test a child process
    
assert self._parent_pid == os.getpid(), 'can only test a child process'
 Exception ignored in:  ^
^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^
^   Exception ignored in:  Traceback (most recent call last):
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    AssertionError^
  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>    : Traceback (most recent call last):
self._shutdown_workers()
 ^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^can only test a child process 
self._shutdown_workers() ^Traceback (most recent call last):
 

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      ^^    if w.is_alive():    ^^if w.is_alive():^
^self._shutdown_workers()^^ 
Exception ignored in: ^^    <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
 self._shutdown_workers() ^
^^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Traceback (most recent call last):
 ^ ^ 
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      ^^AssertionError      ^if w.is_alive():^: 
 self._shutdown_workers() 
^^can only test a child process  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^
  ^    
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^if w.is_alive(): ^^^
 ^    ^Exception ignored in: ^ ^^^if w.is_alive():  ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^ 
^^^ 
Traceback (most recent call last):
^^ ^  ^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
 ^^^^     ^^^^^ ^^  ^ ^^^  ^^ ^^^^^ 
^^^^^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^self._shutdown_workers()^    ^^^^
^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      ^^^if w.is_alive():^^^
 ^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^  ^^      ^^^^^^ 
^ ^assert self._parent_pid == os.getpid(), 'can only test a child process'  ^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
^^ ^^^^  ^^      ^assert self._parent_pid == os.getpid(), 'can only test a child process'^   ^
^^^  ^^ ^
^  ^
 ^^AssertionError 
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 ^ ^
       File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
: ^ AssertionError^^assert self._parent_pid == os.getpid(), 'can only test a child process'can only test a child process : ^  ^
can only test a child process ^    
^
 ^assert self._parent_pid == os.getpid(), 'can only test a child process'^    ^^ ^^
^^   ^ ^^ ^ Exception ignored in:  ^^^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^ 
  ^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
^ ^     ^Traceback (most recent call last):
 ^assert self._parent_pid == os.getpid(), 'can only test a child process'^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^ ^
^ ^^     ^^ ^  ^self._shutdown_workers()  ^^^ ^ 
^^^^ ^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
 ^^    ^^^^ ^^^if w.is_alive():^^ ^
^^^ ^ ^^^^ ^^  ^^^^ ^ ^^ ^^^ ^^ ^  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^AssertionError
^: ^^^AssertionError^^: ^can only test a child process^^^^^^can only test a child process
^^^^
^^^^^^^^^
^^^^^^^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^    ^^
^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^
AssertionError^: ^ ^
can only test a child process AssertionError
 ^ ^: can only test a child process^^
^ ^^^ ^ ^Exception ignored in: 
 <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^
AssertionError ^Traceback (most recent call last):
: ^ ^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
can only test a child process^^    ^
self._shutdown_workers()^^^
^^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

    if w.is_alive():^AssertionError
: ^can only test a child process
  ^     ^^^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^  ^^ ^^ ^ ^  ^ ^ ^
  AssertionError^: can only test a child process^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L23_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L03_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_O07_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_N12_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_G18_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I07_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_P09_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L10_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_C17_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I05_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_A08_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K03_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_N16_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_E02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K12_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_F24_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_J01_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_A17_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L17_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_B16_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_F12_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_O08_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_D17_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_M21_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_D20_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_D22_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_M21_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_C20_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K03_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_M24_T0001F007L01A01Z01C01.tif
# Persist the extracted test-set features/targets so they can be reloaded
# later without re-running the (slow, S3-backed) feature extraction.
torch.save(dict(features=tst_features, targets=tst_targets), "tst_features.pt")

Set up the model training as an optimization problem

# Linear probe on the 576-d backbone features: a 2-d bottleneck (reused
# later for 2-D visualization) followed by 3 class logits.
classifier = torch.nn.Sequential(
    torch.nn.Linear(in_features=576, out_features=2, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=2, out_features=3, bias=False),
)
if torch.cuda.is_available():
    classifier.cuda()

optimizer = torch.optim.SGD(classifier.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Pair each pre-computed feature vector with its label; shuffle only the
# training split.
trn_feat_dl = DataLoader(list(zip(features, targets)), batch_size=100, shuffle=True)
val_feat_dl = DataLoader(list(zip(val_features, val_targets)), batch_size=100, shuffle=False)

# Per-epoch metric histories.
avg_loss_trn, avg_acc_trn = [], []
avg_loss_val, avg_acc_val = [], []

n_epochs = 100
q = tqdm(total=n_epochs)

for e in range(n_epochs):
    # Per-epoch class counters (0 = NONE/DMSO, 1 = CRISPR, 2 = ORF), used to
    # report the class balance the model actually saw.
    n_dmso = 0
    n_crispr = 0
    n_orf = 0

    # ---- Training loop ----
    classifier.train()

    loss_epoch = 0.0
    acc_epoch = 0.0
    for x, y in trn_feat_dl:
        optimizer.zero_grad()

        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()

        y_pred = classifier(x.squeeze())
        loss = loss_fn(y_pred, y)

        loss.backward()

        optimizer.step()

        loss_epoch += loss.item()
        # .item() keeps the accumulators as plain floats so the metric
        # histories are plottable and hold no (possibly GPU) tensor refs.
        acc_epoch += (torch.sum(y_pred.argmax(dim=1) == y) / len(y)).item()

        n_dmso += torch.sum(y == 0).item()
        n_crispr += torch.sum(y == 1).item()
        n_orf += torch.sum(y == 2).item()

    avg_loss_trn.append(loss_epoch / len(trn_feat_dl))
    avg_acc_trn.append(acc_epoch / len(trn_feat_dl))

    n_total = n_dmso + n_crispr + n_orf
    trn_class_props = [n_dmso / n_total, n_crispr / n_total, n_orf / n_total]

    # ---- Validation loop ----
    classifier.eval()

    n_dmso = 0
    n_crispr = 0
    n_orf = 0

    loss_epoch = 0.0
    acc_epoch = 0.0
    for x_val, y_val in val_feat_dl:
        with torch.no_grad():

            # BUG FIX: the original moved the stale *training* batch (x, y)
            # to the GPU here, leaving the validation batch on the CPU and
            # crashing when the classifier lives on the GPU.
            if torch.cuda.is_available():
                x_val = x_val.cuda()
                y_val = y_val.cuda()

            y_val_pred = classifier(x_val.squeeze())
            loss = loss_fn(y_val_pred, y_val)

        loss_epoch += loss.item()
        acc_epoch += (torch.sum(y_val_pred.argmax(dim=1) == y_val) / len(y_val)).item()

        # BUG FIX: the original counted classes from the last *training*
        # batch (y) instead of the validation batch (y_val).
        n_dmso += torch.sum(y_val == 0).item()
        n_crispr += torch.sum(y_val == 1).item()
        n_orf += torch.sum(y_val == 2).item()

    avg_loss_val.append(loss_epoch / len(val_feat_dl))
    avg_acc_val.append(acc_epoch / len(val_feat_dl))

    n_total = n_dmso + n_crispr + n_orf
    val_class_props = [n_dmso / n_total, n_crispr / n_total, n_orf / n_total]

    q.set_description(f"Average training loss: {avg_loss_trn[-1]:0.4f} (Accuracy: {100 * avg_acc_trn[-1]:0.4f} %). Average validation loss: {avg_loss_val[-1]:04f} (Accuracy: {100 * avg_acc_val[-1]:0.4f} %)")
    q.update()
trn_class_props, val_class_props
import matplotlib.pyplot as plt

# Loss curves.
plt.plot(avg_loss_trn, "k-", label="Training loss")
plt.plot(avg_loss_val, "b:", label="Validation loss")
plt.legend()
# BUG FIX: without closing the first figure, the accuracy curves below would
# be drawn on top of the loss curves when this runs as a plain script.
plt.show()

# Accuracy curves. float() guards against histories stored as 0-d tensors
# (possibly on the GPU), which matplotlib cannot plot directly.
plt.plot([float(a) for a in avg_acc_trn], "k-", label="Training accuracy")
plt.plot([float(a) for a in avg_acc_val], "b:", label="Validation accuracy")
plt.legend()
plt.show()
# Re-create the loaders with a large batch so a single batch yields a big
# random sample of features for visualization.
trn_feat_dl = DataLoader(list(zip(features, targets)), batch_size=2000, shuffle=True)
val_feat_dl = DataLoader(list(zip(val_features, val_targets)), batch_size=2000, shuffle=True)

x_trn, y_trn = next(iter(trn_feat_dl))
x_val, y_val = next(iter(val_feat_dl))

# Project the features through the first (576 -> 2) layer only, giving a
# 2-D embedding for plotting.
classifier.eval()
# BUG FIX: the classifier may have been moved to the GPU above, while the
# loader batches are on the CPU; run on the model's device and bring the
# result back to the CPU so matplotlib can plot it.
device = next(classifier.parameters()).device
with torch.no_grad():
    fx_trn = classifier[0](x_trn.to(device)).cpu()
    fx_val = classifier[0](x_val.to(device)).cpu()
fx_trn.shape
# Human-readable names for the integer class labels. Only indices 0-2 can be
# predicted by the 3-output classifier; index 3 presumably labels compound
# wells used later — confirm against the label encoding.
# BUG FIX: corrected the "COMPUND" typo in the plotted legend label.
class_names = ["NONE/DMSO", "CRISPR", "ORF", "COMPOUND"]
markers = ['o', 's', '^', 'v']

# 2-D bottleneck embedding of the training features, one marker per class.
for y_idx, class_name in enumerate(class_names):
    plt.scatter(x=fx_trn[y_trn == y_idx, 0], y=fx_trn[y_trn == y_idx, 1], marker=markers[y_idx], label=class_name)

plt.legend()
plt.show()

# Same plot for the validation split.
for y_idx, class_name in enumerate(class_names):
    plt.scatter(x=fx_val[y_val == y_idx, 0], y=fx_val[y_val == y_idx, 1], marker=markers[y_idx], label=class_name)

plt.legend()
plt.show()

Check which compounds are similar according to their phenotypic profile

# Anonymous (unauthenticated) access to the public Cell Painting Gallery
# S3 bucket.
s3 = s3fs.S3FileSystem(anon=True)
comp_plate_maps
compound_wells_metadata
# Select the first compound plate map; the double brackets keep the
# single-row selection as a DataFrame rather than a Series.
comp_plate_map = comp_plate_maps.iloc[[0]]
comp_plate_map

Load the image from the AWS bucket

# Fetch one well's image data and its perturbation label from S3.
# NOTE(review): the meaning of the positional arguments (1, 1, 0, 5) is not
# visible here — confirm against load_well's definition.
x_comp, y_comp = load_well(comp_plate_map, wells_metadata, 1, 1, 0, 5, s3)

Add a dummy axis to treat a single sample as a batch of size one

# Prepend a batch axis so the single sample can be processed as a batch of one.
x_comp = torch.from_numpy(x_comp[None, ...])
x_comp.shape, x_comp.dtype, y_comp

Extract features with the baseline model

b, c, h, w = x_comp.shape
# Treat every (batch, channel) pair as one grayscale image and replicate it
# to 3 channels, as expected by the pretrained backbone's transforms.
x_comp_t = model_transforms(torch.tile(x_comp.reshape(-1, 1, h, w), (1, 3, 1, 1)))

with torch.no_grad():
    x_out = model(x_comp_t)
    # Sum per-channel feature maps back into one map per sample, then pool to
    # a single flat feature vector per batch element.
    # NOTE(review): assumes the backbone emits 576x7x7 feature maps — confirm
    # against the model definition used earlier in the notebook.
    x_out = x_out.detach().reshape(-1, c, 576, 7, 7).sum(dim=1)
    x_out = org_avgpool(x_out).detach().reshape(b, -1)

Predict the type of perturbation with the classifier model

# Classify the compound sample and project it into the 2-D bottleneck space.
# NOTE(review): x_out appears to live on the CPU while the classifier may be
# on the GPU — this would raise a device mismatch when CUDA is available;
# confirm.
classifier.eval()
with torch.no_grad():
    y_pred_comp = classifier(x_out)
    fx_comp = classifier[0](x_out)
y_pred_comp.argmax(), class_names[y_pred_comp.argmax().item()],  class_names[y_comp]
# Re-draw the training-feature embedding and overlay the new compound sample
# ("x" marker) to see which perturbation cluster it falls closest to.
markers = ['o', 's', '^', 'v']
for class_idx, label in enumerate(class_names):
    mask = y_trn == class_idx
    plt.scatter(x=fx_trn[mask, 0], y=fx_trn[mask, 1], marker=markers[class_idx], label=label)

plt.scatter(x=fx_comp[0, 0], y=fx_comp[0, 1], marker="x", label="Test")

plt.legend()
plt.show()